{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "86c0a1b5-41fd-47b9-9991-671205ab80c1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
      "Failed to download (trying next):\n",
      "HTTP Error 404: Not Found\n",
      "\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 9912422/9912422 [00:11<00:00, 850260.77it/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
      "Failed to download (trying next):\n",
      "HTTP Error 404: Not Found\n",
      "\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 28881/28881 [00:00<00:00, 126526.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
      "Failed to download (trying next):\n",
      "HTTP Error 404: Not Found\n",
      "\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1648877/1648877 [00:01<00:00, 1038065.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
      "Failed to download (trying next):\n",
      "HTTP Error 404: Not Found\n",
      "\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4542/4542 [00:00<00:00, 1881719.55it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# 设置随机种子（确保可复现性）\n",
    "torch.manual_seed(42)\n",
    "\n",
    "# 1. 数据加载与预处理\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),          # 转为Tensor并归一化到[0,1]\n",
    "    transforms.Normalize((0.1307,), (0.3081,))  # MNIST的均值和标准差\n",
    "])\n",
    "\n",
    "# 下载并加载数据集\n",
    "train_dataset = torchvision.datasets.MNIST(\n",
    "    root='./data', train=True, download=True, transform=transform)\n",
    "test_dataset = torchvision.datasets.MNIST(\n",
    "    root='./data', train=False, download=True, transform=transform)\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=64, shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    test_dataset, batch_size=1000, shuffle=False)\n",
    "\n",
    "# 2. 定义CNN模型\n",
    "class CNN(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(CNN, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)  # 输入通道1，输出通道32\n",
    "        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)\n",
    "        self.pool = nn.MaxPool2d(2, 2)  # 池化层（2x2窗口，步长2）\n",
    "        self.fc1 = nn.Linear(64 * 7 * 7, 128)  # 全连接层（MNIST经过两次池化后为7x7）\n",
    "        self.fc2 = nn.Linear(128, 10)          # 输出10类（0-9）\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.pool(torch.relu(self.conv1(x)))  # [64, 32, 14, 14]\n",
    "        x = self.pool(torch.relu(self.conv2(x)))  # [64, 64, 7, 7]\n",
    "        x = x.view(-1, 64 * 7 * 7)               # 展平\n",
    "        x = torch.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return x\n",
    "\n",
    "model = CNN()\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "\n",
    "# 3. 定义损失函数和优化器\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# 4. 训练模型\n",
    "def train(model, device, train_loader, optimizer, criterion, epochs=5):\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        running_loss = 0.0\n",
    "        for batch_idx, (data, target) in enumerate(train_loader):\n",
    "            data, target = data.t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "992588a1-7279-4678-bfda-4395caa87b3d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
